import math
import copy
from dataclasses import dataclass
from typing import Optional, Tuple
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

from transformers.models.t5.modeling_t5 import (T5ForConditionalGeneration,
                                                T5PreTrainedModel,
                                                T5Stack,
                                                T5Block)
from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, BaseModelOutput
from models.models.sampler import KnowledgeSampler
from models.models.kfm import KFM
from torch_geometric.utils import to_dense_batch

class T5ForKnowledgeAugmentedGeneration(nn.Module):
    """T5 generator wrapped with a knowledge sampler and auxiliary losses.

    The forward pass first lets ``self.knowledge_sampler`` pick ``k``
    knowledge samples per example, runs the generator on the expanded batch,
    and - when ``labels`` are given - returns a marginalised negative
    log-likelihood plus optional contrastive and retrieval losses.
    """

    def __init__(self, args, entity_embeddings):
        super().__init__()
        self.generator = CustomT5ForConditionalGeneration.from_pretrained("t5-small", args=args)
        self.knowledge_sampler = KnowledgeSampler(args, entity_embeddings)
        self.k = args.num_samples  # Top-k for Marginalization (Training)
        # Not used by any method of this class; kept so existing checkpoints
        # (state dicts) keep loading.
        self.discriminator = nn.Bilinear(args.hidden_size, args.hidden_size, 1)
        self.args = args
        if args.use_contrastive:
            # Projection heads for the decoder (r2l) and graph (g2l) summaries.
            self.r2l = nn.Linear(args.hidden_size, args.hidden_size)
            self.g2l = nn.Linear(args.hidden_size, args.hidden_size)
            # Learnable log-scales; applied through exp() as a multiplier on
            # the cosine-similarity logits.
            self.batch_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
            self.self_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    # https://github.com/huggingface/transformers/blob/e68c3756fea7c811d02b8470539ae17ec3ec0e71/src/transformers/models/rag/modeling_rag.py#L1222
    def marginalize(self, seq_logits, doc_scores, n_docs):
        """Marginalise token log-probabilities over the sampled documents.

        Args:
            seq_logits: logits of shape ``(batch * n_docs, seq_len, vocab)``.
            doc_scores: unnormalised sample scores of shape ``(batch, n_docs)``.
            n_docs: number of samples per example.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab)`` holding
            ``log sum_z p(z) * p(token | z)``.
        """
        seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
            seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
        )
        doc_logprobs = torch.log_softmax(doc_scores, dim=1)
        # Broadcast the per-document log-prior over (seq_len, vocab).
        log_prob_sum = seq_logprobs + doc_logprobs[:, :, None, None]
        return torch.logsumexp(log_prob_sum, dim=1)

    def retrieval_loss(self, scores, labels, batch):
        """Binary cross-entropy between retrieval scores and gold labels.

        ``scores`` and ``labels`` are flat per-item tensors and ``batch``
        assigns each item to an example; ``to_dense_batch`` pads them into a
        dense per-example layout together with a validity mask.

        The loss is averaged over the valid items of each example and then
        over the examples that contain at least one positive label; examples
        without any positive label do not contribute.
        """
        labels, mask = to_dense_batch(labels, batch)
        scores, _ = to_dense_batch(scores, batch)

        # 1.0 for examples that have at least one positive label.
        weight = (labels.sum(-1) > 0).float()
        labels = labels.float()
        loss_fn = nn.BCEWithLogitsLoss(reduction='none')

        loss = loss_fn(scores, labels)
        # The epsilons guard against empty examples / batches without positives.
        loss = (loss * mask).sum(-1) / (mask.sum(-1) + 1e-10)
        loss = (loss * weight).sum(-1) / (weight.sum(-1) + 1e-10)
        return loss

    def contrastive_loss(
        self,
        encoder_mask,
        decoder_mask,
        mention_positions,
        encoder_hidden_states,
        decoder_hidden_states,
        temperature=0.07,
        base_temperature=0.07,
        k=1,
    ):
        """Symmetric contrastive loss between graph and decoder summaries.

        Args:
            encoder_mask: unused; kept for interface compatibility.
            decoder_mask: mask over decoder positions, multiplied into
                ``decoder_hidden_states`` before mean pooling.
            mention_positions: ``(batch * k, graph_len)`` tensor where ``-1``
                marks padding. The first ``graph_len`` encoder positions are
                treated as the graph states.
            encoder_hidden_states: encoder output, ``(batch * k, enc_len, dim)``.
            decoder_hidden_states: decoder output, ``(batch * k, dec_len, dim)``.
            temperature: unused; the learnable ``batch_scale`` /
                ``self_scale`` parameters are used instead.
            base_temperature: unused; kept for interface compatibility.
            k: number of knowledge samples per example.

        Returns:
            Scalar loss tensor. Examples without any valid graph position are
            excluded as anchors (they still appear as negatives).
        """
        graph_len = mention_positions.shape[1]
        graph_mask = mention_positions != -1
        keep_case = graph_mask.sum(-1) > 0  # examples that actually have a graph

        graph_hidden_states = encoder_hidden_states[:, :graph_len]

        # Masked mean pooling; fully-masked rows are forced to exactly zero.
        dec_emb = decoder_hidden_states * decoder_mask.unsqueeze(2)
        dec_mean_emb = torch.sum(dec_emb, 1) / (decoder_mask.sum(-1).unsqueeze(1).float() + 1e-10)
        dec_mean_emb = dec_mean_emb * (decoder_mask.sum(-1) > 0).unsqueeze(-1)

        graph_emb = graph_hidden_states * graph_mask.unsqueeze(2)
        graph_mean_emb = torch.sum(graph_emb, 1) / (graph_mask.sum(-1).unsqueeze(1).float() + 1e-10)
        graph_mean_emb = graph_mean_emb * keep_case.unsqueeze(-1)

        # Project and L2-normalise so that dot products are cosine similarities.
        enc_feature = F.normalize(self.g2l(graph_mean_emb), p=2.0, dim=-1)
        dec_feature = F.normalize(self.r2l(dec_mean_emb), p=2.0, dim=-1)

        n_examples = enc_feature.shape[0] // k
        ce = nn.CrossEntropyLoss(reduction='none')
        # Same number of kept anchors in every term below.
        n_kept = keep_case.sum().float() + 1e-10

        # (B*K) x D -> K x B x D: contrast across the batch, per sample index.
        graph = enc_feature.view(n_examples, k, -1).transpose(0, 1)
        dec = dec_feature.view(n_examples, k, -1).transpose(0, 1)
        keep_mask = keep_case.view(n_examples, k).transpose(0, 1)

        # K x B x B similarity matrix; entry [s, i, j] = graph_i . decoder_j.
        batch_logits = torch.matmul(graph, dec.transpose(1, 2)) * self.batch_scale.exp()
        # Build targets on the logits' device rather than args.device so the
        # loss works wherever the module actually lives.
        labels = torch.arange(batch_logits.shape[1], device=batch_logits.device).repeat(k, 1)

        # CrossEntropyLoss normalises over dim 1; applying it to the matrix
        # and to its transpose covers both retrieval directions.
        forward_term = (ce(batch_logits, labels) * keep_mask).sum() / n_kept
        backward_term = (ce(batch_logits.transpose(1, 2), labels) * keep_mask).sum() / n_kept
        cont_loss = (forward_term + backward_term) / 2.0

        if self.args.use_self_contrastive:
            # B x K x D: contrast the K samples of one example with each other.
            graph = enc_feature.view(n_examples, k, -1)
            dec = dec_feature.view(n_examples, k, -1)
            keep_mask = keep_case.view(n_examples, k)

            self_logits = torch.matmul(graph, dec.transpose(1, 2)) * self.self_scale.exp()
            self_labels = torch.arange(
                self_logits.shape[1], device=self_logits.device
            ).repeat(n_examples, 1)

            self_cont_loss = (ce(self_logits, self_labels) * keep_mask).sum() / n_kept
            self_cont_loss = self_cont_loss + (
                ce(self_logits.transpose(1, 2), self_labels) * keep_mask
            ).sum() / n_kept

            cont_loss = cont_loss + self_cont_loss * 0.5
        return cont_loss

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None, # Belows are the additional inputs
        mention_positions=None,
        nodes=None,
        edge_index=None,
        edge_attr=None,
        edge_labels=None,
        graph_batch=None,
        local_indicator=None,
    ):
        """Sample knowledge, run the generator and (optionally) compute losses.

        Without ``labels`` a single knowledge sample is drawn and the raw
        generator output tuple is returned; no loss is computed on that path.

        With ``labels`` ``self.k`` samples are drawn per example and the
        return value is ``(loss_dict,) + generator_outputs`` where
        ``loss_dict`` has the keys ``total_loss``, ``gold_loss`` (always
        zero), ``contrastive_loss`` and ``retrieval_loss``.

        Note: ``return_dict`` is accepted for interface compatibility but the
        generator is always called with ``return_dict=False``. Label id ``0``
        is treated as padding and ignored in the likelihood.

        Raises:
            ValueError: if ``labels`` are given, ``args.unsupervised`` is
                false and ``edge_labels`` is missing.
        """
        k = 1 if labels is None else self.k
        (
            input_ids,
            attention_mask,
            _probs,
            scores,
            graph_inputs,
            _augment_graph_inputs,
            edge_batch,
            full_scores,
        ) = self.knowledge_sampler(
            input_ids,
            attention_mask,
            mention_positions,
            nodes,
            edge_index,
            edge_attr,
            graph_batch,
            local_indicator,
            k,
        )

        decoder_input_ids, decoder_attention_mask = self.knowledge_sampler.expand_decoder_input(
            decoder_input_ids, decoder_attention_mask, k,
        )

        outputs = self.generator(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=None,  # the marginalised loss is computed below instead
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=False,
            graph_inputs=graph_inputs,
        )

        # Inference: the auxiliary losses would be discarded anyway, and their
        # inputs (edge labels, decoder masks) are typically absent here.
        if labels is None:
            return outputs

        lm_logits = outputs[0]
        loss_device = lm_logits.device

        if self.args.use_contrastive:
            contrastive_loss = self.contrastive_loss(
                attention_mask,
                decoder_attention_mask,
                graph_inputs["mention_positions"],
                outputs[1],  # encoder last hidden state
                outputs[2],  # decoder output fed to the LM head
                k=k,
            )
        else:
            contrastive_loss = torch.zeros((), device=loss_device)

        if edge_labels is not None:
            retrieval_loss = self.retrieval_loss(full_scores, edge_labels, edge_batch)
        elif self.args.unsupervised:
            # No supervision requested and none available: report zero.
            retrieval_loss = torch.zeros((), device=loss_device)
        else:
            raise ValueError(
                "edge_labels must be provided when args.unsupervised is False"
            )

        # -log(y|x) Loss Computation from REALM w/ Marginalization
        scores = scores.view(lm_logits.shape[0] // k, k)
        rag_logprobs = self.marginalize(lm_logits, scores, k)
        target = labels.unsqueeze(-1)

        ll = rag_logprobs.gather(dim=-1, index=target)
        # Padding positions (label id 0) contribute nothing to the likelihood.
        ll = ll.masked_fill(target.eq(0), 0.0).squeeze(-1)
        nll_loss = -ll.sum(1)

        loss = nll_loss.mean() + contrastive_loss
        if not self.args.unsupervised:
            loss = loss + retrieval_loss

        loss_dict = {
            'total_loss': loss,
            # Placeholder; kept on the same device as the other entries so the
            # dict can be moved / gathered uniformly.
            'gold_loss': torch.zeros((), device=loss.device),
            'contrastive_loss': contrastive_loss,
            'retrieval_loss': retrieval_loss,
        }
        return (loss_dict,) + outputs

class CustomT5ForConditionalGeneration(T5ForConditionalGeneration):
    """T5 seq2seq model whose encoder is a graph-aware ``CustomT5Stack``.

    Differences from the parent class visible here:

    * the encoder is replaced by ``CustomT5Stack`` and receives an extra
      ``graph_inputs`` argument;
    * with ``return_dict=False`` the output tuple starts with
      ``(lm_logits, encoder_last_hidden_state, decoder_sequence_output)``
      before the usual decoder and encoder entries.
    """

    def __init__(self, config, args):
        super().__init__(config)
        self.args = args
        # self.kfm = KFM(args)

        encoder_config = copy.deepcopy(config)
        encoder_config.is_decoder = False
        encoder_config.use_cache = False
        encoder_config.is_encoder_decoder = False

        # Replace the stock encoder; it shares the token embedding matrix.
        self.encoder = CustomT5Stack(encoder_config, self.shared, args)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        inputs_embeds=None,
        decoder_inputs_embeds=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        graph_inputs=None,
    ):
        """Encode (unless ``encoder_outputs`` is given), decode and project.

        Args:
            graph_inputs: extra inputs forwarded to the encoder only; ignored
                when ``encoder_outputs`` is supplied.
            labels: if given, a token-level cross-entropy loss (ignore index
                ``-100``) is computed and, when ``decoder_input_ids`` /
                ``decoder_inputs_embeds`` are absent, decoder inputs are
                derived by shifting the labels right.

        Returns:
            ``Seq2SeqLMOutput`` when ``return_dict`` is true; otherwise the
            tuple ``(lm_logits, encoder_hidden, sequence_output) +
            decoder_outputs[1:] + encoder_outputs``, prefixed with the loss
            when ``labels`` are given. ``sequence_output`` is the decoder
            output after the optional tied-embedding rescale.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode if needed (training, first prediction pass)
        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
                inputs_embeds=inputs_embeds,
                head_mask=head_mask,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                graph_inputs=graph_inputs,
            )
        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
            # Wrap a plain tuple so attribute access works further down.
            encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs[0],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
            )

        hidden_states = encoder_outputs[0]

        # KFM here
        # hidden_states = self.kfm(hidden_states, **graph_inputs)

        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
            # get decoder inputs from shifting lm labels to the right
            decoder_input_ids = self._shift_right(labels)

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.decoder.first_device)
            hidden_states = hidden_states.to(self.decoder.first_device)
            if decoder_input_ids is not None:
                decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
            if attention_mask is not None:
                attention_mask = attention_mask.to(self.decoder.first_device)
            if decoder_attention_mask is not None:
                decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)

        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            inputs_embeds=decoder_inputs_embeds,
            past_key_values=past_key_values,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = decoder_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.encoder.first_device)
            self.lm_head = self.lm_head.to(self.encoder.first_device)
            sequence_output = sequence_output.to(self.lm_head.weight.device)

        if self.config.tie_word_embeddings:
            # Rescale output before projecting on vocab
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
            sequence_output = sequence_output * (self.model_dim ** -0.5)

        lm_logits = self.lm_head(sequence_output)

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
            # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

        if not return_dict:
            # Caller-supplied encoder outputs may be a ModelOutput (dict-like),
            # which cannot be concatenated to a tuple directly.
            if hasattr(encoder_outputs, "to_tuple"):
                encoder_tuple = encoder_outputs.to_tuple()
            else:
                encoder_tuple = tuple(encoder_outputs)
            output = (
                (lm_logits, hidden_states, sequence_output)
                + tuple(decoder_outputs[1:])
                + encoder_tuple
            )
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=encoder_outputs.last_hidden_state,
            encoder_hidden_states=encoder_outputs.hidden_states,
            encoder_attentions=encoder_outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs
    ):
        """Assemble the keyword arguments for one decoding step.

        ``graph_inputs`` is passed through when present in ``kwargs``; it is
        only consumed by the encoder, so it may be missing (``None``) when
        ``encoder_outputs`` have already been computed.
        """
        # cut decoder_input_ids if past is used
        if past is not None:
            input_ids = input_ids[:, -1:]

        # return_dict.update(adapter_inputs)
        return {
            "decoder_input_ids": input_ids,
            "past_key_values": past,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "use_cache": use_cache,
            "graph_inputs": kwargs.get("graph_inputs"),
        }


class CustomT5Stack(T5Stack):
    """T5 stack that fuses graph knowledge into the token embeddings.

    When called with ``input_ids`` (and no ``inputs_embeds``) the token
    embeddings are looked up here, passed through ``self.kfm`` together with
    ``graph_inputs`` and then handed to the regular ``T5Stack.forward`` as
    ``inputs_embeds``. Caller-provided ``inputs_embeds`` are forwarded
    unchanged.
    """

    def __init__(self, config, embed_tokens=None, args=None):
        super().__init__(config, embed_tokens)
        self.kfm = KFM(args)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        graph_inputs=None,
    ):
        """Run the stack, applying the KFM module to freshly embedded tokens.

        Args:
            graph_inputs: mapping of keyword arguments for ``self.kfm``. When
                it is ``None`` the fusion step is skipped and the plain token
                embeddings are used.

        Raises:
            ValueError: if token ids are given but the stack has no embedding
                table. Missing ``input_ids`` *and* ``inputs_embeds`` is
                reported by the parent ``forward``.
        """
        if inputs_embeds is None and input_ids is not None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)
            if graph_inputs is not None:
                inputs_embeds = self.kfm(inputs_embeds, **graph_inputs)
            # The parent must see embeddings only, never both ids and embeddings.
            input_ids = None

        # Keyword arguments keep this call robust against parameter-order
        # changes in the parent signature.
        return super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            inputs_embeds=inputs_embeds,
            head_mask=head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
